import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms


# BaseDataset offers a set of data-transform methods and utilities.
class BaseDataset(Dataset):
    """Base class for datasets that share one PIL/torchvision preprocessing policy.

    It stores the options object and offers helpers to sample per-item
    augmentation parameters (``get_params``) and to build the matching
    transform pipelines (``get_transform`` / ``get_sketch_transform``).
    It defines neither ``__getitem__`` nor ``__len__``; presumably concrete
    subclasses do.

    Attributes read from ``opt`` by this class: ``preprocess_mode``,
    ``load_size``, ``image_size``, ``crop_size``, ``aspect_ratio``,
    ``no_flip`` and ``phase``.
    """

    def __init__(self, opt):
        super().__init__()
        # Options namespace; every preprocessing decision below is read from it.
        self.opt = opt

    def get_params(self, size):
        """Sample per-item augmentation parameters for an image of ``size``.

        The image size after the resize stage of ``opt.preprocess_mode`` is
        predicted here so the random crop origin falls inside the resized image.

        Args:
            size: ``(width, height)`` of the source image (PIL ``Image.size`` order).

        Returns:
            dict with ``'crop_pos'`` -> ``(x, y)`` top-left crop corner and
            ``'flip'`` -> bool (True with probability 0.5).
        """
        opt = self.opt
        w, h = size
        new_w, new_h = w, h
        if opt.preprocess_mode == 'resize_and_crop':
            new_h = new_w = opt.load_size
        elif opt.preprocess_mode == 'scale_width_and_crop':
            new_w = opt.load_size
            new_h = opt.load_size * h // w
        elif opt.preprocess_mode == 'scale_shortside_and_crop':
            ss, ls = min(w, h), max(w, h)  # short side and long side
            width_is_shorter = w == ss
            ls = int(opt.load_size * ls / ss)
            # The short side becomes load_size, mirroring __scale_shortside.
            if width_is_shorter:
                new_w, new_h = opt.load_size, ls
            else:
                new_w, new_h = ls, opt.load_size

        # If the resized image is smaller than image_size, the origin is pinned to 0.
        x = random.randint(0, max(0, new_w - opt.image_size))
        y = random.randint(0, max(0, new_h - opt.image_size))

        flip = random.random() > 0.5
        return {'crop_pos': (x, y), 'flip': flip}

    def get_transform(self, params, method=Image.BICUBIC, normalize=True, toTensor=True):
        """Build the image transform pipeline described by ``opt.preprocess_mode``.

        Args:
            params: dict from ``get_params``. ``'crop_pos'`` and ``'flip'`` are
                read when cropping/flipping apply; an optional ``'rotate'``
                angle is honored during training.
            method: PIL resampling filter used by every resize/rotate step.
            normalize: append ``Normalize`` with mean/std 0.5 on three channels,
                mapping ``ToTensor``'s ``[0, 1]`` range to ``[-1, 1]``.
                ``Normalize`` operates on tensors, so this expects
                ``toTensor=True`` (or a tensor produced elsewhere).
            toTensor: append ``ToTensor`` to convert the PIL image to a tensor.

        Returns:
            ``transforms.Compose`` applying the steps in order.
        """
        return self._build_transform(params, method, normalize, toTensor,
                                     (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    def isTrain(self):
        """Return True when ``opt.phase`` equals 'train' (surrounding whitespace ignored)."""
        return self.opt.phase.strip() == 'train'

    def get_sketch_transform(self, params, method=Image.BICUBIC, normalize=True, toTensor=True):
        """Build the same pipeline as ``get_transform`` with scalar normalization.

        The only difference is ``Normalize(0.5, 0.5)``: a scalar mean/std that
        broadcasts over any number of channels. Presumably this is intended for
        single-channel sketch images (confirm with callers).

        Returns:
            ``transforms.Compose`` applying the steps in order.
        """
        return self._build_transform(params, method, normalize, toTensor, 0.5, 0.5)

    def _build_transform(self, params, method, normalize, to_tensor, norm_mean, norm_std):
        """Assemble the shared transform pipeline; only normalization stats vary.

        Step order: resize stage -> crop -> 'none'/'fixed' resizing ->
        flip (train) -> rotate (train) -> ToTensor -> Normalize.
        """
        opt = self.opt
        mode = opt.preprocess_mode
        steps = []

        # Resize stage. Substring match, so e.g. 'resize_and_crop' also resizes;
        # at most one of these branches applies.
        if 'resize' in mode:
            osize = [opt.load_size, opt.load_size]
            steps.append(transforms.Resize(osize, interpolation=method))
        elif 'scale_width' in mode:
            steps.append(transforms.Lambda(
                lambda img: self.__scale_width(img, opt.load_size, method)))
        elif 'scale_shortside' in mode:
            steps.append(transforms.Lambda(
                lambda img: self.__scale_shortside(img, opt.load_size, method)))

        if 'crop' in mode:
            # The crop origin is sampled once in get_params and reused here.
            steps.append(transforms.Lambda(
                lambda img: self.__crop(img, params['crop_pos'], opt.image_size)))

        if mode == 'none':
            # Only snap both sides to the nearest multiple of 32.
            base = 32
            steps.append(transforms.Lambda(
                lambda img: self.__make_power_2(img, base, method)))

        if mode == 'fixed':
            # Fixed output: width crop_size, height derived from aspect_ratio (= w / h).
            # NOTE(review): this uses crop_size while cropping uses image_size.
            w = opt.crop_size
            h = round(opt.crop_size / opt.aspect_ratio)
            steps.append(transforms.Lambda(
                lambda img: self.__resize(img, w, h, method)))

        is_train = self.isTrain()
        if is_train and not opt.no_flip:
            steps.append(transforms.Lambda(
                lambda img: self.__flip(img, params['flip'])))

        # get_params never sets 'rotate'; callers may add it to request rotation.
        if is_train and 'rotate' in params:
            steps.append(transforms.Lambda(
                lambda img: self.__rotate(img, params['rotate'], method)))

        if to_tensor:
            steps.append(transforms.ToTensor())

        if normalize:
            steps.append(transforms.Normalize(norm_mean, norm_std))
        return transforms.Compose(steps)

    def __resize(self, img, w, h, method=Image.BICUBIC):
        """Resize ``img`` to exactly ``(w, h)``."""
        return img.resize((w, h), method)

    def __make_power_2(self, img, base, method=Image.BICUBIC):
        """Resize ``img`` so each side is the nearest multiple of ``base``.

        Despite the name, sides are rounded to multiples of ``base`` rather than
        to powers of two. Each side is kept at least ``base`` so that very small
        images do not round down to a zero-sized resize.
        """
        ow, oh = img.size
        w = max(base, int(round(ow / base) * base))
        h = max(base, int(round(oh / base) * base))
        if (w, h) == (ow, oh):
            return img
        return img.resize((w, h), method)

    def __scale_width(self, img, target_width, method=Image.BICUBIC):
        """Resize ``img`` to width ``target_width``, keeping the aspect ratio (height truncated)."""
        ow, oh = img.size
        if ow == target_width:
            return img
        return img.resize((target_width, int(target_width * oh / ow)), method)

    def __scale_shortside(self, img, target_width, method=Image.BICUBIC):
        """Resize ``img`` so its shorter side equals ``target_width``, keeping the aspect ratio.

        The longer side is scaled by the same factor and truncated to an int.
        """
        ow, oh = img.size
        ss, ls = min(ow, oh), max(ow, oh)  # short side and long side
        width_is_shorter = ow == ss
        if ss == target_width:
            return img
        ls = int(target_width * ls / ss)
        # The short side itself must become target_width, not stay at its old length.
        nw, nh = (target_width, ls) if width_is_shorter else (ls, target_width)
        return img.resize((nw, nh), method)

    def __crop(self, img, pos, size):
        """Return the ``size`` x ``size`` region of ``img`` whose top-left corner is ``pos``.

        No clamping is done here. If the box extends past the image, PIL's
        ``crop`` determines how the outside region is filled.
        """
        x1, y1 = pos
        return img.crop((x1, y1, x1 + size, y1 + size))

    def __flip(self, img, flip):
        """Mirror ``img`` horizontally when ``flip`` is truthy; otherwise return it unchanged."""
        if flip:
            return img.transpose(Image.FLIP_LEFT_RIGHT)
        return img

    def __rotate(self, img, deg, method=Image.BICUBIC):
        """Rotate ``img`` by ``deg`` degrees using PIL's ``rotate`` with ``method`` resampling."""
        return img.rotate(deg, resample=method)

    def __add1(self, img):
        # Not called anywhere in this class.
        # NOTE(review): for 8-bit images np.array(img) is uint8, so 255 + 1 wraps
        # to 0. Confirm this is intended before relying on this helper.
        return Image.fromarray(np.array(img) + 1)
